#include "invert_index_query.h"

#include <cstdint>
#include <unordered_map>

// Constructs an InvertIndex with an explicitly value-initialized IndexTableBase.
// No DTC connection is set up here; server_ is configured by InitServer().
InvertIndex::InvertIndex()
    : IndexTableBase()
{
}

// Destroys the InvertIndex. Nothing is released explicitly here; members and the
// base class are torn down by their own destructors.
InvertIndex::~InvertIndex()
{
}

// Configures the DTC client (server_) for the index table and checks connectivity.
//
// server_ is switched to string keys, bound to dtchost.szTablename and to
// `bindAddr`, and given a timeout of 300 (unit per DTC's SetMTimeout; presumably
// milliseconds - confirm).
//
// @param dtchost   supplies the table name (szTablename).
// @param bindAddr  address of the DTC instance to query.
// @return 0 when Ping() succeeds, -DTC::EC_TABLE_MISMATCH when Ping() reports a
//         table mismatch (tolerated and passed back to the caller), or -1 on any
//         other Ping() failure.
int InvertIndex::InitServer(
    const SDTCHost &dtchost,
    const std::string& bindAddr)
{
    const std::string& master_address = bindAddr;
    log_info("master address is  [%s]", master_address.c_str());

    server_.StringKey();
    server_.SetTableName(dtchost.szTablename.c_str());
    server_.SetAddress(master_address.c_str());
    server_.SetMTimeout(300);

    const int ret = server_.Ping();
    // A table mismatch is not treated as fatal here; every other error is.
    if (ret != 0 && ret != -DTC::EC_TABLE_MISMATCH) {
        log_error("ping slave[%s] failed, err: %d", master_address.c_str(), ret);
        return -1;
    }

    return ret;
}

// Filters `vecs` down to the documents whose version matches the stored snapshot.
//
// Snapshot rows (doc_id, doc_version, content) are fetched with getSnapshotExecute
// in batches of at most 32 keys. For every entry of `vecs` whose (doc_id,
// doc_version) equals a fetched row:
//   - the doc_id is registered via ResultContext::Instance()->SetValidDocs,
//   - (doc_id, row content) is insert()-ed into `doc_content_map`,
//   - if `need_version` is set, (doc_id, doc_version) is insert()-ed into
//     `valid_version` (std::map::insert does not overwrite an existing key).
// If several rows match, the first fetched one is used.
//
// @return false if any batch request fails, true otherwise.
bool InvertIndex::DocValid(
    uint32_t appid,
    const std::vector<IndexInfo>& vecs,
    bool need_version,
    std::map<std::string, uint32_t>& valid_version,
    hash_string_map& doc_content_map)
{
    const int kBatchSize = 32;  // DTC batch get accepts at most 32 keys per request
    const int doc_size = static_cast<int>(vecs.size());
    std::vector<DocVersionInfo> docVersionInfo;

    for (int left = 0; left < doc_size; left += kBatchSize) {
        const int right = (doc_size - left > kBatchSize) ? left + kBatchSize : doc_size;
        if (!getSnapshotExecute(left, right, appid, vecs, docVersionInfo)) {
            return false;
        }
    }

    // Index snapshot rows by doc_id so each lookup is O(1) on average instead of a
    // full scan. Row indices are stored in fetch order, so the first matching row
    // wins, exactly as with a front-to-back scan.
    std::unordered_map<std::string, std::vector<size_t>> rows_by_doc;
    rows_by_doc.reserve(docVersionInfo.size());
    for (size_t j = 0; j < docVersionInfo.size(); ++j) {
        rows_by_doc[docVersionInfo[j].doc_id].push_back(j);
    }

    for (size_t i = 0; i < vecs.size(); ++i) {
        std::string doc_id = vecs[i].doc_id;
        const uint32_t doc_version = vecs[i].doc_version;

        const auto hit = rows_by_doc.find(doc_id);
        if (hit == rows_by_doc.end()) {
            continue;
        }

        for (size_t j : hit->second) {
            if (docVersionInfo[j].doc_version != doc_version) {
                continue;
            }
            ResultContext::Instance()->SetValidDocs(doc_id);
            doc_content_map.insert(std::make_pair(doc_id, docVersionInfo[j].content));
            if (need_version) {
                valid_version.insert(std::make_pair(doc_id, doc_version));
            }
            break;
        }
    }
    return true;
}

bool InvertIndex::TopDocValid(
    uint32_t appid,
    std::vector<TopDocInfo>& no_filter_docs,
    std::vector<TopDocInfo>& doc_info)
{
    int numbers = 32; //DTC批量查找的上限为32个
    int docSize = no_filter_docs.size();
    int count = docSize / numbers;
    int remain = docSize % numbers;
    std::vector<DocVersionInfo> docVersionInfo;

    for (int index = 0; index < count; index++) 
    {
        int left = index * numbers;
        int right = (index + 1) * numbers;
        if (!getTopSnapshotExecute(left, right, appid, no_filter_docs, docVersionInfo))
            return false;
    }

    if (!getTopSnapshotExecute(docSize-remain, docSize, appid, no_filter_docs, docVersionInfo)) {
        return false;
    }

    for (size_t i = 0; i < no_filter_docs.size(); i++)
    {
        TopDocInfo info = no_filter_docs[i];
        for (size_t j = 0; j < docVersionInfo.size(); j++)
        {
            if ((info.doc_id == docVersionInfo[j].doc_id) && (info.doc_version == docVersionInfo[j].doc_version)) {
                doc_info.push_back(info);
                break;
            }
        }
    }

    return true;
}

bool InvertIndex::GetDocInfo(
    uint32_t appid,
    const std::string& word,
    uint32_t field_id,
    std::vector<IndexInfo>& doc_info)
{ 
     DTC::Server* dtcServer = &server_;
    if (NULL == dtcServer) {
        log_error("dtcServer is null !");
        return false;
    }

    DTC::GetRequest getReq(dtcServer);
    int ret = getReq.SetKey(word.c_str() , word.length());
    if (field_id != 0 && field_id != INT_MAX){
        ret = getReq.EQ("field", field_id);
    }
    ret = getReq.EQ("start_time", 0);
    ret = getReq.EQ("end_time", 0);
    ret = getReq.Need("doc_id");
    ret = getReq.Need("doc_version");
    ret = getReq.Need("field");
    ret = getReq.Need("word_freq");
    ret = getReq.Need("location");
    ret = getReq.Need("created_time");

    DTC::Result rst;
    ret = getReq.Execute(rst);
    if (ret != 0) {
        if (ret == -110) {
            rst.Reset();
            ret = getReq.Execute(rst);
            if (ret != 0) {
                log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
                return false;
            }
        }
        else {
            log_error("get request[%s] error! errcode %d,errmsg %s, errfrom %s", word.c_str(), ret, rst.ErrorMessage(), rst.ErrorFrom());
            return false;
        }
    }

    int cnt = rst.NumRows();
    if (cnt <= 0) {
        log_debug("not find in index, key[%s]", word.c_str());
        doc_info.clear();
    }
    else {
        for (int i = 0; i < cnt; i++) {
            rst.FetchRow();
            static IndexInfo info;
            info.appid = appid;
            info.doc_id = rst.StringValue("doc_id");
            info.doc_version = rst.IntValue("doc_version");
            info.word_freq = rst.IntValue("word_freq");
            info.field = rst.IntValue("field");
            info.pos = rst.StringValue("location");
            info.created_time = rst.IntValue("created_time");

            if(doc_info.empty() || info.doc_id.compare(doc_info.back().doc_id) != 0){
                doc_info.push_back(std::move(info));
            } else if(info.doc_version > doc_info.back().doc_version){
                doc_info.back().doc_version = info.doc_version;
                doc_info.back().word_freq = info.word_freq;
                doc_info.back().field = info.field;
                doc_info.back().pos = info.pos;
                doc_info.back().created_time = info.created_time;
            }
        }

    }
    return true;
}

bool InvertIndex::GetTopDocInfo(
    uint32_t appid,
    std::string word,
    std::vector<TopDocInfo>& doc_info)
{
    log_debug("appid [%d], word[%s]", appid, word.c_str());
    int ret;
    DTC::Server* dtcServer = &server_;

    if (NULL == dtcServer) {
        log_error("dtcServer is null !");
        return false;
    }

    std::stringstream ss_key;
    ss_key<<appid;
    ss_key<<"#01#";
    ss_key<<word;

    time_t tm = time(0);
    DTC::GetRequest getReq(dtcServer);

    ret = getReq.SetKey(ss_key.str().c_str());
    ret = getReq.LE("start_time", tm);
    ret = getReq.GE("end_time", tm);
    ret = getReq.Need("doc_id");
    ret = getReq.Need("doc_version");
    ret = getReq.Need("created_time");
    ret = getReq.Need("weight");

    DTC::Result rst;
    ret = getReq.Execute(rst);
    if (ret != 0) {
        if (ret == -110) {
            rst.Reset();
            ret = getReq.Execute(rst);
            if (ret != 0) {
                log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
                return false;
            }
        }
        else {
            log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
            return false;
        }
    }
    int cnt = rst.NumRows();
    std::string doc_id;
    if (rst.NumRows() <= 0) {
        log_debug("not find in index, key[%s]", word.c_str());
        doc_info.clear();
    }
    else {
        for (int i = 0; i < cnt; i++) {
            rst.FetchRow();
            TopDocInfo info;
            info.appid = appid;
            info.doc_id = rst.StringValue("doc_id");
            info.doc_version = rst.IntValue("doc_version");
            info.created_time = rst.IntValue("created_time");
            info.weight = rst.IntValue("weight");
            doc_info.push_back(info);
        }

    }
    return true;
}

// Returns the document count recorded for `appid`.
//
// The count is read as a decimal string (via atoi) from the "extend" column of the
// row whose integer key is (appid << 32) + 0x7FFFFFFF. A request failing with
// -110 is retried once.
//
// @return the stored count, or the fallback 10000 when the request fails or no row
//         exists.
int InvertIndex::GetDocCnt(
    uint32_t appid)
{
    int doc_cnt = 10000;
    DTC::Server* dtcServer = &server_;

    if (NULL == dtcServer) {
        log_error("dtcServer is null !");
        return doc_cnt;
    }

    DTC::GetRequest getReq(dtcServer);

    // Build the key in unsigned arithmetic. Left-shifting a signed long long that
    // holds appid >= 2^31 by 32 bits overflows, which is undefined behaviour.
    const uint64_t raw_key = (static_cast<uint64_t>(appid) << 32) + 0x7FFFFFFFu;
    const long long search_id = static_cast<long long>(raw_key);
    getReq.SetKey(search_id);
    getReq.Need("extend");

    DTC::Result rst;
    int ret = getReq.Execute(rst);
    if (ret != 0) {
        if (ret == -110) {
            // Retry once with a cleared result object, as the other lookups do.
            rst.Reset();
            ret = getReq.Execute(rst);
            if (ret != 0) {
                log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
                return doc_cnt;
            }
        }
        else {
            log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
            return doc_cnt;
        }
    }
    if (rst.NumRows() > 0) {
        rst.FetchRow();
        doc_cnt = atoi(rst.StringValue("extend"));
    }
    log_debug("doc count: %d.", doc_cnt);

    return doc_cnt;
}

bool InvertIndex::GetScoreByField(
    uint32_t appid,
    std::string doc_id,
    std::string sort_field,
    uint32_t sort_type,
    ScoreInfo &score_info)
{
    int ret = 0;
    DTC::Server* dtcServer = &server_;

    if (NULL == dtcServer) {
        log_error("dtcServer is null !");
        return false;
    }

    std::stringstream ss_key;
    ss_key << appid;
    ss_key << "#10#";
    ss_key << doc_id;

    DTC::GetRequest getReq(dtcServer);
    ret = getReq.SetKey(ss_key.str().c_str());
    ret = getReq.Need("extend");

    DTC::Result rst;
    ret = getReq.Execute(rst);
    if (ret != 0) {
        if (ret == -110) {
            rst.Reset();
            ret = getReq.Execute(rst);
            if (ret != 0) {
                log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
                return false;
            }
        }
        else {
            log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
            return false;
        }
    }

    if (rst.NumRows() <= 0) {
        log_debug("not find in index, key[%s]", ss_key.str().c_str());
    }
    else {
        rst.FetchRow();
        std::string extend = rst.StringValue("extend");
        Json::Reader r(Json::Features::strictMode());
        Json::Value recv_packet;
        int ret2 = 0;
        ret2 = r.parse(extend.c_str(), extend.c_str() + extend.length(), recv_packet);
        if (0 == ret2)
        {
            log_error("the err json is %s", extend.c_str());
            log_error("parse json error , errmsg : %s", r.getFormattedErrorMessages().c_str());
            return false;
        }

        if (recv_packet.isMember(sort_field.c_str()))
        {
            if(recv_packet[sort_field.c_str()].isUInt()){
                score_info.type = FIELDTYPE_INT;
                score_info.i = recv_packet[sort_field.c_str()].asUInt();
                score_info.score = recv_packet[sort_field.c_str()].asUInt();
            } else if(recv_packet[sort_field.c_str()].isString()){
                score_info.type = FIELDTYPE_STRING;
                score_info.str = recv_packet[sort_field.c_str()].asString();
                score_info.score = atoi(recv_packet[sort_field.c_str()].asString().c_str());
            } else if(recv_packet[sort_field.c_str()].isDouble()){
                score_info.type = FIELDTYPE_DOUBLE;
                score_info.d = recv_packet[sort_field.c_str()].asDouble();
                score_info.score = recv_packet[sort_field.c_str()].asDouble();
            } else {
                log_error("sort_field[%s] data type error.", sort_field.c_str());
                return false;
            }
        } else {
            log_error("appid[%u] sort_field[%s] invalid.", appid, sort_field.c_str());
            return false;
        }

    }

    return true;
}

// Loads the snapshot content of every document in `index_infos` into
// `doc_content`, keyed by doc_id.
//
// Work is delegated to getSnapshotContent, one call per consecutive slice of at
// most 32 documents (the DTC batch-lookup limit).
//
// @return false as soon as one slice fails; true otherwise.
bool InvertIndex::GetDocContent(
    uint32_t appid,
    const std::vector<IndexInfo>& index_infos,
    hash_string_map& doc_content)
{
    const int batch_limit = 32;  // DTC batch lookup accepts at most 32 keys
    const int total = index_infos.size();

    int begin = 0;
    while (begin < total) {
        const int end = (total - begin > batch_limit) ? begin + batch_limit : total;
        if (!getSnapshotContent(begin, end, appid, index_infos, doc_content)) {
            return false;
        }
        begin = end;
    }

    return true;
}

// Reports whether a snapshot row exists for `doc_id` under key
// "<appid>#10#<doc_id>".
//
// Only existence is checked: doc_id and doc_version are requested but not
// compared. No retry is attempted on failure.
//
// @param is_valid  set to true if at least one row exists, false otherwise; only
//                  written when the function returns true.
// @return false when building or executing the request fails; true otherwise.
bool InvertIndex::DocValid(
    uint32_t appid,
    std::string doc_id,
    bool &is_valid)
{
    DTC::Server* dtc_server = &server_;
    if (NULL == dtc_server) {
        log_error("dtc_server is null !");
        return false;
    }

    DTC::GetRequest getReq(dtc_server);
    std::stringstream ss_key;
    ss_key << appid << "#10#" << doc_id;

    int ret = getReq.SetKey(ss_key.str().c_str());
    ret |= getReq.Need("doc_id");
    ret |= getReq.Need("doc_version");
    if (0 != ret) {
        log_error("set need field failed. appid:%u, doc_id:%s", appid, doc_id.c_str());
        return false;
    }

    DTC::Result rst;
    ret = getReq.Execute(rst);
    if (0 != ret) {
        log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
        return false;
    }

    is_valid = rst.NumRows() > 0;
    if (!is_valid) {
        log_debug("can not find any result. appid:%u, doc_id:%s", appid, doc_id.c_str());
    }
    return true;
}

// Batch-fetches the "extend" content of index_infos[left, right) and stores it in
// `doc_content[doc_id]`, overwriting existing entries.
//
// Each document is looked up under key "<appid>#10#<doc_id>". The caller is
// expected to keep each range within DTC's batch limit (32 keys elsewhere in this
// file). An empty range is a no-op.
//
// @return false when a key cannot be added, a needed field cannot be set, or the
//         request fails; true otherwise, including when no rows are found.
bool InvertIndex::getSnapshotContent(
    int left,
    int right,
    uint32_t appid,
    const std::vector<IndexInfo>& index_infos,
    hash_string_map& doc_content)
{
    if (left == right) {
        return true;
    }
    DTC::Server* dtc_server = &server_;
    if (NULL == dtc_server) {
        log_error("dtc_server is null !");
        return false;
    }
    dtc_server->AddKey("key", DTC::KeyTypeString);
    DTC::GetRequest getReq(dtc_server);

    std::string docKeys;  // concatenation of all keys, used only in log messages
    int ret = 0;
    for (int i = left; i < right; i++) {
        std::stringstream ss_key;
        ss_key << appid << "#10#" << index_infos[i].doc_id;
        const std::string key = ss_key.str();
        docKeys += key;
        // Accumulate: plain assignment would let a later success hide an earlier failure.
        ret |= getReq.AddKeyValue("key", key.c_str());
    }

    if (0 != ret) {
        log_error("AddKeyValue failed");
        return false;
    }
    ret |= getReq.Need("doc_id");
    ret |= getReq.Need("extend");

    if (0 != ret) {
        log_error("set need field failed. appid:%u, key:%s", appid, docKeys.c_str());
        return false;
    }
    DTC::Result rst;
    ret = getReq.Execute(rst);
    if (ret != 0) {
        log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
        return false;
    }
    const int cnt = rst.NumRows();
    if (cnt <= 0) {
        log_debug("can not find any result. appid:%u, key:%s", appid, docKeys.c_str());
        return true;
    }

    for (int i = 0; i < cnt; ++i) {
        rst.FetchRow();
        std::string doc_id = rst.StringValue("doc_id");
        std::string content = rst.StringValue("extend");
        doc_content[doc_id] = std::move(content);
    }

    return true;
}

bool InvertIndex::getSnapshotExecute(
    int left, 
    int right, 
    uint32_t appid, 
    const std::vector<IndexInfo>& no_filter_docs,
    std::vector<DocVersionInfo>& docVersionInfo) 
{
    if (left == right) {
        return true;
    }

    DTC::Server* dtc_server = &server_;
    if (NULL == dtc_server) {
        log_error("dtc_server is null !");
        return false;
    }
    dtc_server->AddKey("key", DTC::KeyTypeString);

    DTC::GetRequest getReq(dtc_server);
    std::string snapKeys(std::to_string(appid));
    int ret = 0;
    for (int i = left ; i < right; i++) {
        snapKeys += "#10#";
        snapKeys += no_filter_docs[i].doc_id;
        ret = getReq.AddKeyValue("key", snapKeys.c_str());
    }

    if (0 != ret) {
        log_error("AddKeyValue failed");
        return false;
    }

    ret |= getReq.Need("doc_id");
    ret |= getReq.Need("doc_version");
    ret |= getReq.Need("extend");
    if (0 != ret) {
        log_error("set need field failed. appid:%d, doc_id:%s", appid, snapKeys.c_str());
        return false;
    }

    DTC::Result rst;
    ret = getReq.Execute(rst);
    if (ret != 0) {
        log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
        return false;
    }

    int cnt = rst.NumRows();
    if (cnt <= 0) {
        log_debug("can not find any result. appid:%d, doc_id:%s", appid, snapKeys.c_str());
        return true;
    }

    for (int i = 0; i < cnt; ++i) {
        rst.FetchRow();
        static DocVersionInfo info;
        info.doc_id = rst.StringValue("doc_id");
        info.doc_version = rst.IntValue("doc_version");
        info.content = rst.StringValue("extend");
        docVersionInfo.push_back(std::move(info));
    }

    return true;
}

bool InvertIndex::getTopSnapshotExecute(
    int left, 
    int right, 
    uint32_t appid, 
    std::vector<TopDocInfo>& no_filter_docs, 
    std::vector<DocVersionInfo>& docVersionInfo) 
{
    int ret = 0;
    if (left == right) {
        return true;
    }
    DTC::Server* dtc_server = &server_;
    if (NULL == dtc_server) {
        log_error("dtc_server is null !");
        return false;
    }
    dtc_server->AddKey("key", DTC::KeyTypeString);
    DTC::GetRequest getReq(dtc_server);

    std::string snapKeys;
    for (int i = left ; i < right; i++) 
    {
        std::stringstream ss_key;
        ss_key << appid;
        ss_key << "#11#";
        ss_key << no_filter_docs[i].doc_id;
        snapKeys += ss_key.str();
        ret = getReq.AddKeyValue("key", ss_key.str().c_str());
    }

    if(0 != ret)
    {
        log_error("AddKeyValue failed");
        return false;
    }
    ret |= getReq.Need("doc_id");
    ret |= getReq.Need("doc_version");

    if(0 != ret)
    {
        log_error("set need field failed. appid:%d, key:%s", appid, snapKeys.c_str());
        return false;
    }
    DTC::Result rst;
    ret = getReq.Execute(rst);
    if(ret != 0)
    {
        log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
        return false;
     }
    int cnt = rst.NumRows();
    if(cnt <= 0)
    {
        log_debug("can not find any result. appid:%d, key:%s", appid, snapKeys.c_str());
        return true;
    }

    for(int i = 0; i < cnt; ++i)
    {
        DocVersionInfo info;
        rst.FetchRow();
        info.doc_id = rst.StringValue("doc_id");
        info.doc_version = rst.IntValue("doc_version");
        docVersionInfo.push_back(info);
    }

    return true;
}

bool InvertIndex::GetContentByField(
    uint32_t appid,
    std::string doc_id,
    uint32_t doc_version,
    const std::vector<std::string>& fields,
    Json::Value &value)
{
    int ret = 0;
    DTC::Server* dtcServer = &server_;

    if (NULL == dtcServer) {
        log_error("dtcServer is null !");
        return false;
    }

    std::stringstream ss_key;
    ss_key << appid;
    ss_key << "#10#";
    ss_key << doc_id;

    DTC::GetRequest getReq(dtcServer);
    ret = getReq.SetKey(ss_key.str().c_str());
    if(doc_version != 0){
        getReq.EQ("doc_version", doc_version);
    }
    ret = getReq.Need("extend");

    DTC::Result rst;
    ret = getReq.Execute(rst);
    if (ret != 0) {
        if (ret == -110) {
            rst.Reset();
            ret = getReq.Execute(rst);
            if (ret != 0) {
                log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
                return false;
            }
        }
        else {
            log_error("get request error! errcode %d,errmsg %s, errfrom %s", ret, rst.ErrorMessage(), rst.ErrorFrom());
            return false;
        }
    }

    if (rst.NumRows() <= 0) {
        log_debug("not find in index, key[%s]", ss_key.str().c_str());
    }
    else {
        rst.FetchRow();
        std::string extend = rst.StringValue("extend");
        Json::Reader r(Json::Features::strictMode());
        Json::Value recv_packet;
        int ret2 = 0;
        ret2 = r.parse(extend.c_str(), extend.c_str() + extend.length(), recv_packet);
        if (0 == ret2)
        {
            log_error("parse json error [%s], errmsg : %s", extend.c_str(), r.getFormattedErrorMessages().c_str());
            return false;
        }

        for(int i = 0; i < (int)fields.size(); i++){
            if (recv_packet.isMember(fields[i].c_str()))
            {
                if(recv_packet[fields[i].c_str()].isUInt()){
                    value[fields[i].c_str()] = recv_packet[fields[i].c_str()].asUInt();
                } else if(recv_packet[fields[i].c_str()].isString()){
                    value[fields[i].c_str()] = recv_packet[fields[i].c_str()].asString();
                } else if(recv_packet[fields[i].c_str()].isDouble()){
                    value[fields[i].c_str()] = recv_packet[fields[i].c_str()].asDouble();
                } else {
                    log_error("field[%s] data type error.", fields[i].c_str());
                }
            } else {
                log_info("appid[%u] field[%s] invalid.", appid, fields[i].c_str());
            }
        }
    }
    return true;
}